LlamaIndexを完全に理解するチュートリアル その4:ListIndexで埋め込みベクトルを使用する方法
こんちには。
データアナリティクス事業本部 インテグレーション部 機械学習チームの中村です。
「LlamaIndexを完全に理解するチュートリアル その4」では、GPTListIndexで埋め込みベクトルを使う方法を見ていきます。
本記事で使用する用語は以下のその1で説明していますので、そちらも参照ください。
LlamaIndexを完全に理解するチュートリアル |
---|
その1:処理の概念や流れを理解する基礎編(v0.7.9対応) |
その2:テキスト分割のカスタマイズ |
その3:CallbackManagerで内部動作の把握やデバッグを可能にする |
その4:ListIndexで埋め込みベクトルを使用する方法 |
・本記事の内容はその1のv0.7.9版の記事を投稿後、v0.7.9で動作するように修正しています
本記事の内容
LlamaIndexのGPTListIndexは通常、埋め込みベクトルは使用せず全てのノードを使って処理をシマス。
ただしオプションとしては準備されており、クエリとノードの埋め込みの類似度を求め、使用するノードを決定することは可能となっています。
今回はそれを実現する設定方法を見ていきます。
環境準備
その1と同様の方法で準備します。
使用したバージョン情報は以下となります。
- Python : 3.10.11
- langchain : 0.0.234
- llama-index : 0.7.9
- openai : 0.27.8
サンプルコード
ベースのサンプルは以下とします。ノードの選択状況がわかりやすいよう、LlamaDebugHandlerをCallbackManagerに設定しておきます。
from llama_index import SimpleDirectoryReader from llama_index import ListIndex from llama_index import ServiceContext from llama_index.callbacks import CallbackManager, LlamaDebugHandler documents = SimpleDirectoryReader(input_dir="./data").load_data() llama_debug_handler = LlamaDebugHandler() callback_manager = CallbackManager([llama_debug_handler]) service_context = ServiceContext.from_defaults(callback_manager=callback_manager) list_index = ListIndex.from_documents(documents , service_context=service_context) query_engine = list_index.as_query_engine() response = query_engine.query("機械学習に関するアップデートについて300字前後で要約してください。")
LlamaDebugHandlerの出力ログは以下の通りです。
********** Trace: index_construction |_node_parsing -> 0.028296 seconds |_chunking -> 0.015003 seconds |_chunking -> 0.012298 seconds ********** ********** Trace: query |_query -> 41.653571 seconds |_retrieve -> 0.0 seconds |_synthesize -> 41.653571 seconds |_llm -> 9.249877 seconds |_llm -> 12.309853 seconds |_llm -> 20.035586 seconds **********
デフォルトのListIndexは全ノードを使っていることの確認
使用されるノードの選択状況は、LlamaDebugHandlerのRETRIEVEの結果またはresponse.source_nodes
から把握することができます。
from llama_index.callbacks import CBEventType node_list = llama_debug_handler.get_event_pairs(CBEventType.RETRIEVE)[0][1].payload["nodes"] # node_list = response.source_nodes # こちらでも可 node_count = len(node_list) print(f"{node_count=}") for node in node_list: doc_id = node.node.id_ print(f"{doc_id=}")
node_count=8 doc_id='ba03d238-41ad-4a3a-adfc-3d8ac6cdde48' doc_id='c2cd2c7c-f6a1-4a3a-83d5-cb77e56c56b4' doc_id='80ef0554-de26-4149-985d-cd3594b521e1' doc_id='d8e911e9-e800-424d-9aba-c438e11c11a6' doc_id='2db8f1cf-c766-47d6-88b6-13d71a23e678' doc_id='0810e0ee-b1b5-4d00-9652-4cabbafab5bf' doc_id='8b271e44-b4d8-4b9e-9aa3-a5c6ab88c3f9' doc_id='4d27b006-891e-4fdf-aad6-db6f8237998b'
選ばれたノード数は8個となっています。ListIndexに含まれるdoc_idの情報は以下で取得できます。
for doc_id,v in list_index.storage_context.docstore.docs.items(): print(f"{doc_id=}")
doc_id='ba03d238-41ad-4a3a-adfc-3d8ac6cdde48' doc_id='c2cd2c7c-f6a1-4a3a-83d5-cb77e56c56b4' doc_id='80ef0554-de26-4149-985d-cd3594b521e1' doc_id='d8e911e9-e800-424d-9aba-c438e11c11a6' doc_id='2db8f1cf-c766-47d6-88b6-13d71a23e678' doc_id='0810e0ee-b1b5-4d00-9652-4cabbafab5bf' doc_id='8b271e44-b4d8-4b9e-9aa3-a5c6ab88c3f9' doc_id='4d27b006-891e-4fdf-aad6-db6f8237998b'
一致していることが分かり、現状はListIndexの全てのノードをRETRIEVEで選択していることが分かります。
埋め込みベクトルを使って選択する
ListIndexのデフォルト動作は以上ですが、設定によりノードを埋め込みベクトルの類似度で選択することが可能となります。
そのためには、retriever_modeをListRetrieverMode.EMBEDDINGに設定すればOKです。
similarity_top_kも3に設定し、クエリとの類似度が高い順にノードを3つを選択してみます。
from llama_index import SimpleDirectoryReader from llama_index import ListIndex from llama_index import ServiceContext from llama_index.callbacks import CallbackManager, LlamaDebugHandler from llama_index.indices.list.base import ListRetrieverMode documents = SimpleDirectoryReader(input_dir="./data").load_data() llama_debug_handler = LlamaDebugHandler() callback_manager = CallbackManager([llama_debug_handler]) service_context = ServiceContext.from_defaults(callback_manager=callback_manager) list_index = ListIndex.from_documents(documents , service_context=service_context) query_engine = list_index.as_query_engine( retriever_mode=ListRetrieverMode.EMBEDDING , similarity_top_k=3 ) response = query_engine.query("機械学習に関するアップデートについて300字前後で要約してください。")
LlamaDebugHandlerの出力ログは以下の通りです。
********** Trace: index_construction |_node_parsing -> 0.037637 seconds |_chunking -> 0.016638 seconds |_chunking -> 0.018999 seconds ********** ********** Trace: query |_query -> 20.332703 seconds |_retrieve -> 9.652239 seconds |_embedding -> 0.462239 seconds |_embedding -> 0.205543 seconds |_embedding -> 0.24983 seconds |_embedding -> 7.059133 seconds |_embedding -> 0.237001 seconds |_embedding -> 0.362483 seconds |_embedding -> 0.542566 seconds |_embedding -> 0.218405 seconds |_embedding -> 0.306953 seconds |_synthesize -> 10.679464 seconds |_llm -> 10.652076 seconds **********
RETRIEVE時にEmbeddingが動作していることが分かります。
使用されるノードの選択状況をみてみましょう。
from llama_index.callbacks import CBEventType node_list = llama_debug_handler.get_event_pairs(CBEventType.RETRIEVE)[0][1].payload["nodes"] # node_list = response.source_nodes # こちらでも可 node_count = len(node_list) print(f"{node_count=}") for node in node_list: doc_id = node.node.id_ score = node.score print(f"{doc_id=}, {score=}")
node_count=3 doc_id='bb813454-6ded-4a7c-87b9-b287a2662dd9', score=0.8604478015232282 doc_id='e20291ca-d400-4a4e-a356-c5fd1ca2a973', score=0.841312029237466 doc_id='023656f1-693f-407e-8598-a3e17d0bcc62', score=0.8380146604712428
3つのノードがスコアの高い順に抽出されていることが分かります。
注意点:クエリの都度ノードの埋め込みベクトルを求めてしまう
ListRetrieverMode.EMBEDDINGの場合、求めた埋め込みベクトルはデータストアなどに保存しているわけではないため、
ノード抽出時にその都度埋め込みベクトルを再計算するコストが掛かってしまう点は注意が必要です。
データストアを使用するには、埋め込みベクトルをIndexStoreのノード保存するか、もしくは別のSimpleVectorIndexなどを使う方でも無難に実現できます。
今回はListIndexの範囲に収まる前者の方法を見ていきます。
対策:IndexStoreに埋め込みベクトルを含める方法
まずはベースとなるListIndexを作成します。
from llama_index import SimpleDirectoryReader from llama_index import Document from llama_index import GPTListIndex from llama_index import ServiceContext from llama_index.callbacks import CallbackManager, LlamaDebugHandler, CBEventType from llama_index.indices.list.base import ListRetrieverMode documents = SimpleDirectoryReader(input_dir="./data").load_data() llama_debug_handler = LlamaDebugHandler() callback_manager = CallbackManager([llama_debug_handler]) service_context = ServiceContext.from_defaults(callback_manager=callback_manager) list_index = GPTListIndex.from_documents(documents , service_context=service_context)
そしてノードの一覧を取得して、埋め込みベクトルを求めて格納します。
# 埋め込みベクトルを計算 for doc_id, node in list_index.storage_context.docstore.docs.items(): service_context.embed_model.queue_text_for_embedding( doc_id, node.text ) result_ids, result_embeddings = service_context.embed_model.get_queued_text_embeddings() id_to_embed_map = {} for new_id, text_embedding in zip(result_ids, result_embeddings): id_to_embed_map[new_id] = text_embedding # ノードのembedding属性に埋め込みベクトルを格納 node_list = [] for doc_id, node in list_index.storage_context.docstore.docs.items(): node.embedding = id_to_embed_map[doc_id] node_list.append(node) # 修正したノードでインデックスを再構成 _ = list_index.build_index_from_nodes(node_list)
このようにしておけば、クエリ時に埋め込みベクトルをスキップすることができます。
from llama_index.indices.list.base import ListRetrieverMode query_engine = list_index.as_query_engine( retriever_mode=ListRetrieverMode.EMBEDDING , similarity_top_k=3 ) response = query_engine.query("機械学習に関するアップデートについて300字前後で要約してください。")
********** Trace: query |_query -> 9.887931 seconds |_retrieve -> 0.385121 seconds |_embedding -> 0.3596 seconds |_synthesize -> 9.50281 seconds |_llm -> 9.47453 seconds **********
RETRIEVE時のEMBEDDING処理が1回だけ残っていますが、これはクエリ自体の埋め込みベクトルを求めているため、意図通り動いています。
まとめ
いかがでしたでしょうか。
本記事が、今後LlamaIndexをお使いになられる方の参考になれば幸いです。